{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%autoreload 2\n",
    "\n",
    "from IPython import display\n",
    "\n",
    "from utils import Logger\n",
    "\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "from torch.autograd.variable import Variable\n",
    "from torchvision import transforms, datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_FOLDER = './torch_data/VGAN/MNIST'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mnist_data():\n",
    "    compose = transforms.Compose(\n",
    "        [transforms.ToTensor(),\n",
    "         transforms.Normalize([0.5], [0.5])\n",
    "        ])\n",
    "    out_dir = '{}/dataset'.format(DATA_FOLDER)\n",
    "    return datasets.MNIST(root=out_dir, train=True, transform=compose, download=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load data\n",
    "data = mnist_data()\n",
    "# Create loader with data, so that we can iterate over it\n",
    "data_loader = torch.utils.data.DataLoader(data, batch_size=100, shuffle=True)\n",
    "# Num batches\n",
    "num_batches = len(data_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Networks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DiscriminatorNet(torch.nn.Module):\n",
    "    \"\"\"\n",
    "    A three hidden-layer discriminative neural network\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        super(DiscriminatorNet, self).__init__()\n",
    "        n_features = 784\n",
    "        n_out = 1\n",
    "        \n",
    "        self.hidden0 = nn.Sequential( \n",
    "            nn.Linear(n_features, 1024),\n",
    "            nn.LeakyReLU(0.2),\n",
    "            nn.Dropout(0.3)\n",
    "        )\n",
    "        self.hidden1 = nn.Sequential(\n",
    "            nn.Linear(1024, 512),\n",
    "            nn.LeakyReLU(0.2),\n",
    "            nn.Dropout(0.3)\n",
    "        )\n",
    "        self.hidden2 = nn.Sequential(\n",
    "            nn.Linear(512, 256),\n",
    "            nn.LeakyReLU(0.2),\n",
    "            nn.Dropout(0.3)\n",
    "        )\n",
    "        self.out = nn.Sequential(\n",
    "            torch.nn.Linear(256, n_out),\n",
    "            torch.nn.Sigmoid()\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.hidden0(x)\n",
    "        x = self.hidden1(x)\n",
    "        x = self.hidden2(x)\n",
    "        x = self.out(x)\n",
    "        return x\n",
    "    \n",
    "def images_to_vectors(images):\n",
    "    return images.view(images.size(0), 784)\n",
    "\n",
    "def vectors_to_images(vectors):\n",
    "    return vectors.view(vectors.size(0), 1, 28, 28)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GeneratorNet(torch.nn.Module):\n",
    "    \"\"\"\n",
    "    A three hidden-layer generative neural network\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        super(GeneratorNet, self).__init__()\n",
    "        n_features = 100\n",
    "        n_out = 784\n",
    "        \n",
    "        self.hidden0 = nn.Sequential(\n",
    "            nn.Linear(n_features, 256),\n",
    "            nn.LeakyReLU(0.2)\n",
    "        )\n",
    "        self.hidden1 = nn.Sequential(            \n",
    "            nn.Linear(256, 512),\n",
    "            nn.LeakyReLU(0.2)\n",
    "        )\n",
    "        self.hidden2 = nn.Sequential(\n",
    "            nn.Linear(512, 1024),\n",
    "            nn.LeakyReLU(0.2)\n",
    "        )\n",
    "        \n",
    "        self.out = nn.Sequential(\n",
    "            nn.Linear(1024, n_out),\n",
    "            nn.Tanh()\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.hidden0(x)\n",
    "        x = self.hidden1(x)\n",
    "        x = self.hidden2(x)\n",
    "        x = self.out(x)\n",
    "        return x\n",
    "    \n",
    "# Noise\n",
    "def noise(size):\n",
    "    n = Variable(torch.randn(size, 100))\n",
    "    if torch.cuda.is_available(): return n.cuda() \n",
    "    return n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "discriminator = DiscriminatorNet()\n",
    "generator = GeneratorNet()\n",
    "if torch.cuda.is_available():\n",
    "    discriminator.cuda()\n",
    "    generator.cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Optimization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optimizers\n",
    "d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002)\n",
    "g_optimizer = optim.Adam(generator.parameters(), lr=0.0002)\n",
    "\n",
    "# Loss function\n",
    "loss = nn.BCELoss()\n",
    "\n",
    "# Number of steps to apply to the discriminator\n",
    "d_steps = 1  # In Goodfellow et. al 2014 this variable is assigned to 1\n",
    "# Number of epochs\n",
    "num_epochs = 200"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def real_data_target(size):\n",
    "    '''\n",
    "    Tensor containing ones, with shape = size\n",
    "    '''\n",
    "    data = Variable(torch.ones(size, 1))\n",
    "    if torch.cuda.is_available(): return data.cuda()\n",
    "    return data\n",
    "\n",
    "def fake_data_target(size):\n",
    "    '''\n",
    "    Tensor containing zeros, with shape = size\n",
    "    '''\n",
    "    data = Variable(torch.zeros(size, 1))\n",
    "    if torch.cuda.is_available(): return data.cuda()\n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_discriminator(optimizer, real_data, fake_data):\n",
    "    # Reset gradients\n",
    "    optimizer.zero_grad()\n",
    "    \n",
    "    # 1.1 Train on Real Data\n",
    "    prediction_real = discriminator(real_data)\n",
    "    # Calculate error and backpropagate\n",
    "    error_real = loss(prediction_real, real_data_target(real_data.size(0)))\n",
    "    error_real.backward()\n",
    "\n",
    "    # 1.2 Train on Fake Data\n",
    "    prediction_fake = discriminator(fake_data)\n",
    "    # Calculate error and backpropagate\n",
    "    error_fake = loss(prediction_fake, fake_data_target(real_data.size(0)))\n",
    "    error_fake.backward()\n",
    "    \n",
    "    # 1.3 Update weights with gradients\n",
    "    optimizer.step()\n",
    "    \n",
    "    # Return error\n",
    "    return error_real + error_fake, prediction_real, prediction_fake\n",
    "\n",
    "def train_generator(optimizer, fake_data):\n",
    "    # 2. Train Generator\n",
    "    # Reset gradients\n",
    "    optimizer.zero_grad()\n",
    "    # Sample noise and generate fake data\n",
    "    prediction = discriminator(fake_data)\n",
    "    # Calculate error and backpropagate\n",
    "    error = loss(prediction, real_data_target(prediction.size(0)))\n",
    "    error.backward()\n",
    "    # Update weights with gradients\n",
    "    optimizer.step()\n",
    "    # Return error\n",
    "    return error"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate Samples for Testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_test_samples = 16\n",
    "test_noise = noise(num_test_samples)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Start training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "logger = Logger(model_name='VGAN', data_name='MNIST')\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    for n_batch, (real_batch,_) in enumerate(data_loader):\n",
    "\n",
    "        # 1. Train Discriminator\n",
    "        real_data = Variable(images_to_vectors(real_batch))\n",
    "        if torch.cuda.is_available(): real_data = real_data.cuda()\n",
    "        # Generate fake data\n",
    "        fake_data = generator(noise(real_data.size(0))).detach()\n",
    "        # Train D\n",
    "        d_error, d_pred_real, d_pred_fake = train_discriminator(d_optimizer,\n",
    "                                                                real_data, fake_data)\n",
    "\n",
    "        # 2. Train Generator\n",
    "        # Generate fake data\n",
    "        fake_data = generator(noise(real_batch.size(0)))\n",
    "        # Train G\n",
    "        g_error = train_generator(g_optimizer, fake_data)\n",
    "        # Log error\n",
    "        logger.log(d_error, g_error, epoch, n_batch, num_batches)\n",
    "\n",
    "        # Display Progress\n",
    "        if (n_batch) % 100 == 0:\n",
    "            display.clear_output(True)\n",
    "            # Display Images\n",
    "            test_images = vectors_to_images(generator(test_noise)).data.cpu()\n",
    "            logger.log_images(test_images, num_test_samples, epoch, n_batch, num_batches);\n",
    "            # Display status Logs\n",
    "            logger.display_status(\n",
    "                epoch, num_epochs, n_batch, num_batches,\n",
    "                d_error, g_error, d_pred_real, d_pred_fake\n",
    "            )\n",
    "        # Model Checkpoints\n",
    "        logger.save_models(generator, discriminator, epoch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
